import numpy as np
import matplotlib.pyplot as plt

##############################################################################################################################
                                             # Import data
##############################################################################################################################

def _load_regret_stats(path):
    """Load one regret archive and summarise it across runs.

    Parameters
    ----------
    path : str
        Path of an ``.npz`` archive containing the arrays ``"tsave"`` and
        ``"Regret"``.

    Returns
    -------
    tsave : numpy.ndarray
        The ``"tsave"`` array (expected shape ``(M,)``).
    regret : numpy.ndarray
        The ``"Regret"`` array, shape ``(R, M)``.
    mean_regret : numpy.ndarray
        Mean of ``regret`` over axis 0, shape ``(M,)``.
    sem_regret : numpy.ndarray
        Standard error of that mean: sample standard deviation (``ddof=1``)
        over axis 0 divided by ``sqrt(R)``.

    Raises
    ------
    ValueError
        If ``"Regret"`` is not two-dimensional (the shape unpacking fails).
    """
    # np.load on an .npz returns an NpzFile that holds the file open; the
    # context manager closes it as soon as both arrays have been read.
    with np.load(path) as archive:
        tsave = archive["tsave"]
        regret = archive["Regret"]

    n_runs, _ = regret.shape
    mean_regret = regret.mean(axis=0)
    # NOTE: with a single run (R == 1) the ddof=1 standard deviation is NaN,
    # so the error band is undefined in that case.
    sem_regret = regret.std(axis=0, ddof=1) / np.sqrt(n_runs)
    return tsave, regret, mean_regret, sem_regret


#### No prior

tsaveS2R1, RegretS2R1, mean_regretS2R1, sem_regretS2R1 = _load_regret_stats(
    "regret_KL_UCB_Transfer_Sim2noprior.npz"
)


#### Prior 1

tsaveS2R2, RegretS2R2, mean_regretS2R2, sem_regretS2R2 = _load_regret_stats(
    "regret_KL_UCB_Transfer_Sim2prior1.npz"
)


#### Prior 2

tsaveS2R3, RegretS2R3, mean_regretS2R3, sem_regretS2R3 = _load_regret_stats(
    "regret_KL_UCB_Transfer_Sim2prior2.npz"
)

# R (rows) and M (columns) keep the values they had after the last load
# in the original script: the shape of the "Prior 2" regret array.
R, M = RegretS2R3.shape

##############################################################################################################################
                                             # Plot Simulation 2
##############################################################################################################################

plt.figure(figsize=(5, 3))

# One row per curve, drawn in this order:
# (x values, mean regret, standard error, legend label, colour, line style)
_curves = (
    (tsaveS2R1, mean_regretS2R1, sem_regretS2R1, "No Prior", "b", '-'),
    (tsaveS2R2, mean_regretS2R2, sem_regretS2R2, "Optimistic Prior", "r", '--'),
    (tsaveS2R3, mean_regretS2R3, sem_regretS2R3, "Pessimistic Prior", "g", '-.'),
)

for _t, _mean, _sem, _label, _colour, _style in _curves:
    # Shaded band of one standard error either side of the mean ...
    _lower = _mean - _sem
    _upper = _mean + _sem
    plt.fill_between(_t, _lower, _upper, alpha=0.3, color=_colour)
    # ... with the mean curve drawn on top of it.
    plt.plot(_t, _mean, lw=1.5, label=_label, color=_colour, linestyle=_style)

# Axis decoration: logarithmic x axis, labelled axes, legend and a dashed grid
# on both major and minor ticks.
plt.xscale('log')
plt.xlabel('$T$')
plt.ylabel('$R_T$')
plt.legend()
plt.grid(True, which='both', ls='--', alpha=0.4)
plt.tight_layout()

# Write the figure to disk before showing it.
plt.savefig("plot2ACML.pdf", format="pdf")
plt.show()